# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX is Autograd and XLA

load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_flag")
load("@cuda_cudart//:version.bzl", cuda_major_version = "VERSION")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@rules_python//python:py_binary.bzl", "py_binary")
load(
    "@xla//third_party/py:py_import.bzl",
    "py_import",
)
load(
    "@xla//third_party/py:py_manylinux_compliance_test.bzl",
    "verify_manylinux_compliance_test",
)
load(
    "//jaxlib:jax.bzl",
    "PLATFORM_TAGS_DICT",
    "if_pypi_cuda_wheel_deps",
    "jax_py_test",
    "jax_wheel",
    "pytype_strict_library",
    "pytype_test",
    "wheel_sources",
)

# Package-wide license declaration.
licenses(["notice"])  # Apache 2

# Targets in this package are public unless they set their own visibility.
package(default_visibility = ["//visibility:public"])

# Export the size-test script so it can be referenced by label (it is used as
# `:wheel_size_test.py` by the *_wheel_size_test targets in this file).
exports_files(["wheel_size_test.py"])

# Generates platform_tags.py, a one-line Python module of the form
# `PLATFORM_TAGS_DICT = <dict>`. The dict is the PLATFORM_TAGS_DICT constant
# loaded from //jaxlib:jax.bzl; it is formatted into the shell command with `%`
# when this BUILD file is loaded.
# NOTE(review): this relies on the Starlark string form of the dict also being
# valid Python source, and on it containing no single quotes that would end the
# shell-quoted echo argument — confirm if the dict contents ever change shape.
genrule(
    name = "platform_tags_py",
    srcs = [],
    outs = ["platform_tags.py"],
    cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT,
)

# Helper library shared by all wheel-building binaries in this file. Its
# sources are the checked-in build_utils.py plus the generated platform_tags.py
# produced by :platform_tags_py.
pytype_strict_library(
    name = "build_utils",
    srcs = [
        "build_utils.py",
        ":platform_tags_py",
    ],
    deps = ["@xla//third_party/py:setup_py_nvidia_dependencies_util"],
)

# Platform configurations.

# Matches when the target OS satisfies either the `osx` or the `macos`
# constraint value from @platforms.
selects.config_setting_group(
    name = "macos",
    match_any = [
        "@platforms//os:osx",
        "@platforms//os:macos",
    ],
)

# Matches when the target CPU satisfies either the `aarch64` or the `arm64`
# constraint value from @platforms.
selects.config_setting_group(
    name = "arm64",
    match_any = [
        "@platforms//cpu:aarch64",
        "@platforms//cpu:arm64",
    ],
)

# Matches only when both :arm64 and :macos match.
selects.config_setting_group(
    name = "macos_arm64",
    match_all = [
        ":arm64",
        ":macos",
    ],
)

# Matches only when the CPU is x86_64 and :macos matches.
selects.config_setting_group(
    name = "macos_x86_64",
    match_all = [
        "@platforms//cpu:x86_64",
        ":macos",
    ],
)

# Matches only when the CPU is x86_64 and the OS is Windows.
selects.config_setting_group(
    name = "win_amd64",
    match_all = [
        "@platforms//cpu:x86_64",
        "@platforms//os:windows",
    ],
)

# Matches only when the CPU is x86_64 and the OS is Linux.
selects.config_setting_group(
    name = "linux_x86_64",
    match_all = [
        "@platforms//cpu:x86_64",
        "@platforms//os:linux",
    ],
)

# Matches only when :arm64 matches and the OS is Linux.
selects.config_setting_group(
    name = "linux_aarch64",
    match_all = [
        ":arm64",
        "@platforms//os:linux",
    ],
)

# Flags for the new wheel build rules.

# String build setting, empty by default.
# NOTE(review): presumably carries the git commit hash recorded in the built
# wheels — confirm against the jax_wheel implementation in //jaxlib:jax.bzl.
string_flag(
    name = "jaxlib_git_hash",
    build_setting_default = "",
)

# String build setting that defaults to "dist".
# NOTE(review): presumably the directory the wheel build writes into — confirm
# against the jax_wheel implementation in //jaxlib:jax.bzl.
string_flag(
    name = "output_path",
    build_setting_default = "dist",
)

# Wheel targets.

# Jaxlib wheel targets.
# Executable wrapper around build_wheel.py. It is passed as `wheel_binary` to
# the jaxlib jax_wheel targets below. Depends on :build_utils, the Bazel
# runfiles library, and the pypi `build`, `setuptools` and `wheel` packages.
py_binary(
    name = "build_wheel_tool",
    srcs = ["build_wheel.py"],
    main = "build_wheel.py",
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi//build",
        "@pypi//setuptools",
        "@pypi//wheel",
    ],
)

# Input set for the jaxlib wheel, grouped by the categories that the
# wheel_sources macro (from //jaxlib:jax.bzl) accepts:
#   data_srcs         - //jaxlib, //jaxlib:jaxlib_binaries and //jaxlib:_jax
#   hdr_srcs          - the XLA FFI API target
#   py_srcs           - //jaxlib
#   static_srcs       - README, license, setup.py and xla_client.py
#   symlink_data_srcs - //jaxlib
# NOTE(review): how each category is collected and packaged is defined by
# wheel_sources itself and is not visible from this file.
wheel_sources(
    name = "jaxlib_sources",
    data_srcs = [
        "//jaxlib",
        "//jaxlib:jaxlib_binaries",
        "//jaxlib:_jax",
    ],
    hdr_srcs = [
        "@xla//xla/ffi/api:ffi",
    ],
    py_srcs = [
        "//jaxlib",
    ],
    static_srcs = [
        "//jaxlib:README.md",
        "LICENSE.txt",
        "//jaxlib:setup.py",
        "//jaxlib:xla_client.py",
    ],
    symlink_data_srcs = [
        "//jaxlib",
    ],
)

# The `jaxlib` wheel, produced by :build_wheel_tool from :jaxlib_sources.
# NOTE(review): `no_abi = False` presumably means the wheel keeps a
# Python-ABI-specific tag — confirm in //jaxlib:jax.bzl.
jax_wheel(
    name = "jaxlib_wheel",
    no_abi = False,
    source_files = [":jaxlib_sources"],
    wheel_binary = ":build_wheel_tool",
    wheel_name = "jaxlib",
)

# Editable variant of the `jaxlib` wheel: same sources, tool and wheel name as
# :jaxlib_wheel, but with `editable = True`.
jax_wheel(
    name = "jaxlib_wheel_editable",
    editable = True,
    source_files = [":jaxlib_sources"],
    wheel_binary = ":build_wheel_tool",
    wheel_name = "jaxlib",
)

# JAX plugin wheel targets.
# Library wrapper around //jaxlib:version. It is listed in `py_srcs` of both
# the plugin and the PJRT source sets in this file.
pytype_strict_library(
    name = "version",
    srcs = ["//jaxlib:version"],
)

# Executable wrapper around build_gpu_kernels_wheel.py. It is passed as
# `wheel_binary` to the jax_*_plugin_wheel targets below.
py_binary(
    name = "build_gpu_kernels_wheel_tool",
    srcs = ["build_gpu_kernels_wheel.py"],
    main = "build_gpu_kernels_wheel.py",
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi//build",
        "@pypi//setuptools",
        "@pypi//wheel",
    ],
)

# Source set shared by every GPU plugin wheel target (CUDA and ROCm). Each
# attribute concatenates an if_cuda(...) part and an if_rocm(...) part, so the
# backend-specific entries are contributed through those helpers; :version and
# LICENSE.txt are always included.
wheel_sources(
    name = "jax_plugin_sources",
    data_srcs = if_cuda([
        "//jaxlib/cuda:cuda_gpu_support",
        "@local_config_cuda//cuda:cuda-nvvm",
        "//jaxlib/cuda:cuda_plugin_extension",
        "//jaxlib/mosaic/gpu:mosaic_gpu",
    ]) + if_rocm([
        "//jaxlib/rocm:rocm_gpu_support",
        "//jaxlib/rocm:rocm_plugin_extension",
    ]),
    py_srcs = [
        ":version",
    ] + if_cuda([
        "//jaxlib/cuda:cuda_gpu_support",
        "//jaxlib/mosaic/gpu:mosaic_gpu",
    ]) + if_rocm([
        "//jaxlib/rocm:rocm_gpu_support",
    ]),
    static_srcs = [
        "LICENSE.txt",
    ] + if_cuda([
        "//jax_plugins/cuda:plugin_pyproject.toml",
        "//jax_plugins/cuda:plugin_setup.py",
    ]) + if_rocm([
        "//jax_plugins/rocm:plugin_pyproject.toml",
        "//jax_plugins/rocm:plugin_setup.py",
    ]),
)

# CUDA 12 plugin wheel (`jax_cuda12_plugin`), built by
# :build_gpu_kernels_wheel_tool from :jax_plugin_sources with CUDA enabled.
jax_wheel(
    name = "jax_cuda12_plugin_wheel",
    enable_cuda = True,
    no_abi = False,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    source_files = [":jax_plugin_sources"],
    wheel_binary = ":build_gpu_kernels_wheel_tool",
    wheel_name = "jax_cuda12_plugin",
)

# CUDA 13 plugin wheel (`jax_cuda13_plugin`), built by
# :build_gpu_kernels_wheel_tool from :jax_plugin_sources with CUDA enabled.
jax_wheel(
    name = "jax_cuda13_plugin_wheel",
    enable_cuda = True,
    no_abi = False,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "13",
    source_files = [":jax_plugin_sources"],
    wheel_binary = ":build_gpu_kernels_wheel_tool",
    wheel_name = "jax_cuda13_plugin",
)

# Editable variant of the CUDA 12 plugin wheel. Unlike the non-editable target
# it sets `editable = True` and does not pass `no_abi`.
jax_wheel(
    name = "jax_cuda12_plugin_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    source_files = [":jax_plugin_sources"],
    wheel_binary = ":build_gpu_kernels_wheel_tool",
    wheel_name = "jax_cuda12_plugin",
)

# Editable variant of the CUDA 13 plugin wheel. Unlike the non-editable target
# it sets `editable = True` and does not pass `no_abi`.
jax_wheel(
    name = "jax_cuda13_plugin_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "13",
    source_files = [":jax_plugin_sources"],
    wheel_binary = ":build_gpu_kernels_wheel_tool",
    wheel_name = "jax_cuda13_plugin",
)

# ROCm plugin wheel, built by :build_gpu_kernels_wheel_tool from
# :jax_plugin_sources with ROCm enabled. The platform version "60" also appears
# in the wheel name (`jax_rocm60_plugin`); keep the two in sync.
jax_wheel(
    name = "jax_rocm_plugin_wheel",
    enable_rocm = True,
    no_abi = False,
    platform_version = "60",
    source_files = [":jax_plugin_sources"],
    wheel_binary = ":build_gpu_kernels_wheel_tool",
    wheel_name = "jax_rocm60_plugin",
)

# Editable variant of the ROCm plugin wheel (`editable = True`, no `no_abi`
# argument); otherwise the same inputs and name as :jax_rocm_plugin_wheel.
jax_wheel(
    name = "jax_rocm_plugin_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = "60",
    source_files = [":jax_plugin_sources"],
    wheel_binary = ":build_gpu_kernels_wheel_tool",
    wheel_name = "jax_rocm60_plugin",
)

# JAX PJRT wheel targets.

# Executable wrapper around build_gpu_plugin_wheel.py. It is passed as
# `wheel_binary` to the jax_*_pjrt_wheel targets below.
py_binary(
    name = "build_gpu_plugin_wheel_tool",
    srcs = ["build_gpu_plugin_wheel.py"],
    main = "build_gpu_plugin_wheel.py",
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi//build",
        "@pypi//setuptools",
        "@pypi//wheel",
    ],
)

# Source set shared by every PJRT wheel target (CUDA and ROCm). Each attribute
# concatenates an if_cuda(...) part and an if_rocm(...) part; :version and
# LICENSE.txt are always included.
wheel_sources(
    name = "jax_pjrt_sources",
    data_srcs = if_cuda([
        "//jax_plugins/cuda:cuda_plugin",
        "//jaxlib/cuda:cuda_gpu_support",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jax_plugins/rocm:rocm_plugin",
        "//jaxlib/rocm:rocm_gpu_support",
    ]),
    py_srcs = [
        ":version",
    ] + if_cuda([
        "//jaxlib/cuda:cuda_gpu_support",
    ]) + if_rocm([
        "//jaxlib/rocm:rocm_gpu_support",
    ]),
    static_srcs = [
        "LICENSE.txt",
    ] + if_cuda([
        "//jax_plugins/cuda:pyproject.toml",
        "//jax_plugins/cuda:setup.py",
        "//jax_plugins/cuda:__init__.py",
    ]) + if_rocm([
        "//jax_plugins/rocm:pyproject.toml",
        "//jax_plugins/rocm:setup.py",
        "//jax_plugins/rocm:__init__.py",
    ]),
)

# CUDA 12 PJRT wheel (`jax_cuda12_pjrt`), built by :build_gpu_plugin_wheel_tool
# from :jax_pjrt_sources. Unlike the plugin wheels it sets `no_abi = True`.
jax_wheel(
    name = "jax_cuda12_pjrt_wheel",
    enable_cuda = True,
    no_abi = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    source_files = [":jax_pjrt_sources"],
    wheel_binary = ":build_gpu_plugin_wheel_tool",
    wheel_name = "jax_cuda12_pjrt",
)

# CUDA 13 PJRT wheel (`jax_cuda13_pjrt`), built by :build_gpu_plugin_wheel_tool
# from :jax_pjrt_sources. Unlike the plugin wheels it sets `no_abi = True`.
jax_wheel(
    name = "jax_cuda13_pjrt_wheel",
    enable_cuda = True,
    no_abi = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "13",
    source_files = [":jax_pjrt_sources"],
    wheel_binary = ":build_gpu_plugin_wheel_tool",
    wheel_name = "jax_cuda13_pjrt",
)

# Editable variant of the CUDA 12 PJRT wheel (`editable = True`, no `no_abi`
# argument).
jax_wheel(
    name = "jax_cuda12_pjrt_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    source_files = [":jax_pjrt_sources"],
    wheel_binary = ":build_gpu_plugin_wheel_tool",
    wheel_name = "jax_cuda12_pjrt",
)

# Editable variant of the CUDA 13 PJRT wheel (`editable = True`, no `no_abi`
# argument).
jax_wheel(
    name = "jax_cuda13_pjrt_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "13",
    source_files = [":jax_pjrt_sources"],
    wheel_binary = ":build_gpu_plugin_wheel_tool",
    wheel_name = "jax_cuda13_pjrt",
)

# ROCm PJRT wheel, built by :build_gpu_plugin_wheel_tool from :jax_pjrt_sources
# with ROCm enabled and `no_abi = True`. The platform version "60" also appears
# in the wheel name (`jax_rocm60_pjrt`); keep the two in sync.
jax_wheel(
    name = "jax_rocm_pjrt_wheel",
    enable_rocm = True,
    no_abi = True,
    platform_version = "60",
    source_files = [":jax_pjrt_sources"],
    wheel_binary = ":build_gpu_plugin_wheel_tool",
    wheel_name = "jax_rocm60_pjrt",
)

# Editable variant of the ROCm PJRT wheel (`editable = True`, no `no_abi`
# argument); otherwise the same inputs and name as :jax_rocm_pjrt_wheel.
jax_wheel(
    name = "jax_rocm_pjrt_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = "60",
    source_files = [":jax_pjrt_sources"],
    wheel_binary = ":build_gpu_plugin_wheel_tool",
    wheel_name = "jax_rocm60_pjrt",
)

# Py_import targets.
# Suffix spliced into most NVIDIA pip repository names below: "_cu12" when the
# CUDA major version (VERSION from @cuda_cudart//:version.bzl) is "12", and the
# empty string for any other version.
cuda_suffix = "_cu12" if cuda_major_version == "12" else ""

# NVIDIA pip packages used as `wheel_deps` by the CUDA py_import targets.
# cudnn, nccl and nvshmem always embed the CUDA major version in the repository
# name (`..._cu<major>`); every other package uses `cuda_suffix`. The nvvm and
# cuda_crt packages are appended only when the CUDA major version is "13".
filegroup(
    name = "nvidia_wheel_deps",
    srcs = [
        "@pypi_nvidia_cublas%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cuda_cupti%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cuda_nvcc%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cuda_nvrtc%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cuda_runtime%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cudnn_cu%s//:pkg" % cuda_major_version,
        "@pypi_nvidia_cufft%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cusolver%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cusparse%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_nccl_cu%s//:pkg" % cuda_major_version,
        "@pypi_nvidia_nvjitlink%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_nvshmem_cu%s//:pkg" % cuda_major_version,
    ] + ([
        "@pypi_nvidia_nvvm%s//:pkg" % cuda_suffix,
        "@pypi_nvidia_cuda_crt%s//:pkg" % cuda_suffix,
    ] if cuda_major_version == "13" else []),
)

# Boolean flag controlling whether the pypi NVIDIA CUDA packages are added as
# wheel deps of the py_import targets. Enabled by default; it is one of the two
# conditions of the :pypi_cuda_wheel_deps config_setting.
bool_flag(
    name = "add_pypi_cuda_wheel_deps",
    build_setting_default = True,
)

# Matches only when :add_pypi_cuda_wheel_deps is True and
# @local_config_cuda//:enable_cuda is True.
# NOTE(review): presumably the condition that if_pypi_cuda_wheel_deps (from
# //jaxlib:jax.bzl) selects on — confirm there.
config_setting(
    name = "pypi_cuda_wheel_deps",
    flag_values = {
        ":add_pypi_cuda_wheel_deps": "True",
        "@local_config_cuda//:enable_cuda": "True",
    },
)

# py_import wrapper around the locally built :jaxlib_wheel; it has no extra
# wheel deps.
py_import(
    name = "jaxlib_py_import",
    wheel = ":jaxlib_wheel",
)

# py_import wrapper around the locally built CUDA plugin wheel whose name
# embeds the configured CUDA major version (e.g. :jax_cuda12_plugin_wheel).
# :nvidia_wheel_deps is attached through if_pypi_cuda_wheel_deps.
py_import(
    name = "jax_cuda_plugin_py_import",
    wheel = ":jax_cuda%s_plugin_wheel" % cuda_major_version,
    wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

# py_import wrapper around the locally built CUDA PJRT wheel whose name embeds
# the configured CUDA major version (e.g. :jax_cuda12_pjrt_wheel).
# :nvidia_wheel_deps is attached through if_pypi_cuda_wheel_deps.
py_import(
    name = "jax_cuda_pjrt_py_import",
    wheel = ":jax_cuda%s_pjrt_wheel" % cuda_major_version,
    wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

# The targets below are used for GPU tests with `--//jax:build_jaxlib=false`.
# py_import wrapper around the CUDA plugin wheel taken from the external
# @pypi_jax_cuda<major>_plugin repository rather than built locally.
# :nvidia_wheel_deps is attached through if_pypi_cuda_wheel_deps.
py_import(
    name = "pypi_jax_cuda_plugin_with_cuda_deps",
    wheel = "@pypi_jax_cuda%s_plugin//:whl" % cuda_major_version,
    wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

# py_import wrapper around the CUDA PJRT wheel taken from the external
# @pypi_jax_cuda<major>_pjrt repository rather than built locally.
# :nvidia_wheel_deps is attached through if_pypi_cuda_wheel_deps.
py_import(
    name = "pypi_jax_cuda_pjrt_with_cuda_deps",
    wheel = "@pypi_jax_cuda%s_pjrt//:whl" % cuda_major_version,
    wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

# Mosaic GPU

# Executable wrapper around build_mosaic_wheel.py. It is passed as
# `wheel_binary` to the mosaic_gpu_wheel_cuda* targets below.
py_binary(
    name = "build_mosaic_wheel_tool",
    srcs = ["build_mosaic_wheel.py"],
    main = "build_mosaic_wheel.py",
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi//build",
        "@pypi//setuptools",
        "@pypi//wheel",
    ],
)

# Input set for the Mosaic GPU wheels. Everything is passed as `static_srcs`:
# the license, the mosaic_gpu.so library, the wheel's setup.py and __init__.py,
# and //jaxlib:version.
wheel_sources(
    name = "mosaic_sources",
    static_srcs = [
        "LICENSE.txt",
        "//jaxlib/mosaic/gpu/wheel:mosaic_gpu.so",
        "//jaxlib/mosaic/gpu/wheel:setup.py",
        "//jaxlib/mosaic/gpu/wheel:__init__.py",
        "//jaxlib:version",
    ],
)

# Mosaic GPU wheel for CUDA 12 (`mosaic_gpu_cuda12`), built by
# :build_mosaic_wheel_tool from :mosaic_sources with `no_abi = True`.
jax_wheel(
    name = "mosaic_gpu_wheel_cuda12",
    enable_cuda = True,
    no_abi = True,
    platform_version = "12",
    source_files = [":mosaic_sources"],
    wheel_binary = ":build_mosaic_wheel_tool",
    wheel_name = "mosaic_gpu_cuda12",
)

# Mosaic GPU wheel for CUDA 13 (`mosaic_gpu_cuda13`), built by
# :build_mosaic_wheel_tool from :mosaic_sources with `no_abi = True`.
jax_wheel(
    name = "mosaic_gpu_wheel_cuda13",
    enable_cuda = True,
    no_abi = True,
    platform_version = "13",
    source_files = [":mosaic_sources"],
    wheel_binary = ":build_mosaic_wheel_tool",
    wheel_name = "mosaic_gpu_cuda13",
)

# Wheel tests.

# Per-architecture Linux tag strings handed to the manylinux compliance tests
# below. Each one is the PLATFORM_TAGS_DICT entry for ("Linux", <arch>) with
# its elements joined by "_".
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])

PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])

X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")])

# Manylinux compliance check for :jaxlib_wheel against the per-architecture
# tags defined in this file. Tagged "manual", so wildcard target patterns skip
# it and it runs only when named explicitly.
verify_manylinux_compliance_test(
    name = "jaxlib_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = [
        "manual",
    ],
    wheel = ":jaxlib_wheel",
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

# Manylinux compliance check for the CUDA plugin wheel of the configured CUDA
# major version. Tagged "manual", so wildcard target patterns skip it and it
# runs only when named explicitly.
verify_manylinux_compliance_test(
    name = "jax_cuda_plugin_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = ["manual"],
    wheel = ":jax_cuda%s_plugin_wheel" % cuda_major_version,
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

# Manylinux compliance check for the CUDA PJRT wheel of the configured CUDA
# major version. Tagged "manual", so wildcard target patterns skip it and it
# runs only when named explicitly.
verify_manylinux_compliance_test(
    name = "jax_cuda_pjrt_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = ["manual"],
    wheel = ":jax_cuda%s_pjrt_wheel" % cuda_major_version,
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

# Manylinux compliance check for the Mosaic GPU wheel of the configured CUDA
# major version. Tagged "manual", so wildcard target patterns skip it and it
# runs only when named explicitly.
verify_manylinux_compliance_test(
    name = "mosaic_gpu_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = ["manual"],
    wheel = ":mosaic_gpu_wheel_cuda%s" % cuda_major_version,
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

# Runs wheel_size_test.py on :jaxlib_wheel, passing the wheel's location and
# `--max-size-mib=110`. The wheel is listed in `data` so that $(location ...)
# can be expanded. Tagged "manual", so it runs only when named explicitly.
pytype_test(
    name = "jaxlib_wheel_size_test",
    srcs = [":wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :jaxlib_wheel)",
        "--max-size-mib=110",
    ],
    data = [":jaxlib_wheel"],
    main = "wheel_size_test.py",
    tags = [
        "manual",
        "notap",
    ],
)

# Runs wheel_size_test.py on the CUDA plugin wheel of the configured CUDA major
# version, passing the wheel's location and `--max-size-mib=20`. The wheel is
# listed in `data` so that $(location ...) can be expanded. Tagged "manual", so
# it runs only when named explicitly.
pytype_test(
    name = "jax_cuda_plugin_wheel_size_test",
    srcs = [":wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :jax_cuda%s_plugin_wheel)" % cuda_major_version,
        "--max-size-mib=20",
    ],
    data = [":jax_cuda%s_plugin_wheel" % cuda_major_version],
    main = "wheel_size_test.py",
    tags = ["manual", "notap"],
)

# Runs wheel_size_test.py on the CUDA PJRT wheel of the configured CUDA major
# version, passing the wheel's location and `--max-size-mib=150`. The wheel is
# listed in `data` so that $(location ...) can be expanded. Tagged "manual", so
# it runs only when named explicitly.
pytype_test(
    name = "jax_cuda_pjrt_wheel_size_test",
    srcs = [":wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :jax_cuda%s_pjrt_wheel)" % cuda_major_version,
        "--max-size-mib=150",
    ],
    data = [":jax_cuda%s_pjrt_wheel" % cuda_major_version],
    main = "wheel_size_test.py",
    tags = ["manual", "notap"],
)

# Runs wheel_size_test.py on the Mosaic GPU wheel of the configured CUDA major
# version, passing the wheel's location and `--max-size-mib=40`. The wheel is
# listed in `data` so that $(location ...) can be expanded. Tagged "manual", so
# it runs only when named explicitly.
pytype_test(
    name = "mosaic_gpu_wheel_size_test",
    srcs = [":wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :mosaic_gpu_wheel_cuda%s)" % cuda_major_version,
        "--max-size-mib=40",
    ],
    data = [":mosaic_gpu_wheel_cuda%s" % cuda_major_version],
    main = "wheel_size_test.py",
    tags = ["manual", "notap"],
)
